{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# 4-bit Quantization of GPT-2 with Auto-GPTQ\n",
        "This notebook is a companion of chapter 5 of the \"Domain Specific LLMs in Action\" book, author Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2024.  \n",
        "The code in this notebook is to introduce readers to 4-bit quantization of a decoder-only language model, [GPT-2](https://huggingface.co/openai-community/gpt2), using the [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) library. It requires hardware acceleration.  \n",
        "More details about the code can be found in the related book's chapter."
      ],
      "metadata": {
        "id": "8XxP44PHQUvB"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing dependecies (AutGPTQ only)."
      ],
      "metadata": {
        "id": "We_YmB7LQ7Nf"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BhufqqQAaz6e"
      },
      "outputs": [],
      "source": [
        "!export BUILD_CUDA_EXT=0\n",
        "!pip install -q auto-gptq"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Force the upgrade to the latest HF's Dataset package. A runtime restart would be probably needed when completed."
      ],
      "metadata": {
        "id": "_9wyW462Cblt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install --force-reinstall datasets"
      ],
      "metadata": {
        "id": "rVh_suZSBByc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Import the required classes/packages"
      ],
      "metadata": {
        "id": "Vn0PGpphRCsK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import random\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "from datasets import load_dataset\n",
        "from transformers import TextGenerationPipeline\n",
        "\n",
        "from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig"
      ],
      "metadata": {
        "id": "mXM8TuXx60Af"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Define a function to load and prepare the test set to be used for model quantization and validation."
      ],
      "metadata": {
        "id": "TPoQu-aORKPo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def get_wikitext2(nsamples, seed, seqlen, tokenizer):\n",
        "    # set seed\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.random.manual_seed(seed)\n",
        "\n",
        "    # load dataset and preprocess\n",
        "    traindata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"train\")\n",
        "    testdata = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\")\n",
        "    trainenc = tokenizer(\"\\n\\n\".join(traindata[\"text\"]), return_tensors=\"pt\")\n",
        "    testenc = tokenizer(\"\\n\\n\".join(testdata[\"text\"]), return_tensors=\"pt\")\n",
        "\n",
        "    traindataset = []\n",
        "    for _ in range(nsamples):\n",
        "        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)\n",
        "        j = i + seqlen\n",
        "        inp = trainenc.input_ids[:, i:j]\n",
        "        attention_mask = torch.ones_like(inp)\n",
        "        traindataset.append({\"input_ids\": inp, \"attention_mask\": attention_mask})\n",
        "    return traindataset, testenc"
      ],
      "metadata": {
        "id": "7v5B-ZQ9bFY-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Specify the model ID in the HF's Hub and the destination directory where to save the quantized model."
      ],
      "metadata": {
        "id": "7QIwiY77RWwi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model_id = \"openai-community/gpt2\"\n",
        "quantized_model_dir = \"gpt-2-4bit\""
      ],
      "metadata": {
        "id": "0gpUPSLx64VY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the model tokenizer from the HF's Hub."
      ],
      "metadata": {
        "id": "tW_4tyJJRft3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "try:\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)\n",
        "except Exception:\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)"
      ],
      "metadata": {
        "id": "j7ccCFIj66RG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Set the quantization configuration."
      ],
      "metadata": {
        "id": "LkoWAXj8RrQi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "quantize_config = BaseQuantizeConfig(\n",
        "    bits=4,\n",
        "    group_size=128,\n",
        "    desc_act=False,\n",
        ")"
      ],
      "metadata": {
        "id": "ZjdUlqdh6_-b"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the un-quantized model (it is forced to be loaded into CPU). Then get the maximum sequence lenght for it."
      ],
      "metadata": {
        "id": "2tuQOxD5RvXM"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)\n",
        "model_config = model.config.to_dict()\n",
        "seq_len_keys = [\"max_position_embeddings\", \"seq_length\", \"n_positions\"]\n",
        "if any(k in model_config for k in seq_len_keys):\n",
        "    for key in seq_len_keys:\n",
        "        if key in model_config:\n",
        "            model.seqlen = model_config[key]\n",
        "            break\n",
        "else:\n",
        "    print(\"The model's sequence length cannot be retrieved from its configuration. It will then be set to 2048.\")\n",
        "    model.seqlen = 2048"
      ],
      "metadata": {
        "id": "vyYnnSGdbgj3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load and prepare the dataset for the quantization process."
      ],
      "metadata": {
        "id": "nJTiZV_fSBJB"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "traindataset, testenc = get_wikitext2(128, 0, model.seqlen, tokenizer)"
      ],
      "metadata": {
        "id": "_uDvU4epblHJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Quantize the model. The examples used should be a list of dict whose keys contains \"input_ids\" and \"attention_mask\"."
      ],
      "metadata": {
        "id": "B2d4N36XSIXt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model.quantize(traindataset, use_triton=False)"
      ],
      "metadata": {
        "id": "GxjJUWR27Dw0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Save the quantized model to disk."
      ],
      "metadata": {
        "id": "uybeFbU4SWOp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model.save_quantized(quantized_model_dir, use_safetensors=True)"
      ],
      "metadata": {
        "id": "Jp1BxVlY7KLC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The size of the saved safetensors is 1.02 GB."
      ],
      "metadata": {
        "id": "S-ug6e4mhbFH"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Load the quantized model."
      ],
      "metadata": {
        "id": "3S_i8roH7SH0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "quantized_model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,\n",
        "                                           device=\"cuda:0\", use_triton=False)"
      ],
      "metadata": {
        "id": "htdp8my37NJY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Do inference with the quantized model."
      ],
      "metadata": {
        "id": "iY1IUgluSsHa"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = \"Auto-GPTQ is\"\n",
        "output = tokenizer.decode(\n",
        "    quantized_model.generate(**tokenizer(prompt, return_tensors=\"pt\").to(\"cuda:0\"))[0])\n",
        "print(output)"
      ],
      "metadata": {
        "id": "RymnhESo7RmI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "HF Transformers pipelines are supported too for inference with the 4-bit quantized model."
      ],
      "metadata": {
        "id": "5whCVKeZSuvf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "pipeline = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=\"cuda:0\")\n",
        "print(pipeline(prompt)[0][\"generated_text\"])"
      ],
      "metadata": {
        "id": "61W5tKt4_3V1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Weight Comparison.\n",
        "The following 5 code cells are meant to display the weights of the original and the 4-bit quantized model in a histogram chart, same way as for the 8-bit quantization case presented in the CH05_NB02_Iozzia.ipynb notebook. Please refer to it for more details."
      ],
      "metadata": {
        "id": "mejncb-TvQdk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)"
      ],
      "metadata": {
        "id": "lfS-dS3KyjEH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "weights = [param.data.clone() for param in model.parameters()]\n",
        "weights_int8 = [param.data.clone() for param in quantized_model.parameters()]"
      ],
      "metadata": {
        "id": "DVdvfHYdvVNE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "weights = np.concatenate([t.cpu().numpy().flatten() for t in weights])\n",
        "weights_int8 = np.concatenate([t.cpu().numpy().flatten() for t in weights_int8])"
      ],
      "metadata": {
        "id": "_a5SLYL7wXbx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.ticker as ticker"
      ],
      "metadata": {
        "id": "gWPx93LBv2oH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set background style\n",
        "plt.style.use('ggplot')\n",
        "\n",
        "# Create figure and axes\n",
        "fig, axs = plt.subplots(1, figsize=(10,10), dpi=300, sharex=True)\n",
        "\n",
        "# Plot the histograms for original and zero-point weights\n",
        "axs.hist(weights, bins=150, alpha=0.5, label='Original weights', color='yellow', range=(-0.5, 0.5))\n",
        "axs.hist(weights_int8, bins=150, alpha=0.5, label='LLM.int8() weights', color='blue', range=(-0.5, 0.5))\n",
        "\n",
        "# Add grid\n",
        "axs.grid(True, linestyle='--', alpha=0.6)\n",
        "\n",
        "# Add legend\n",
        "axs.legend()\n",
        "\n",
        "# Add title and labels\n",
        "axs.set_title('Comparison of Original and LLM.int8() Weights', fontsize=16)\n",
        "\n",
        "axs.set_xlabel('Weights', fontsize=14)\n",
        "axs.set_ylabel('Count', fontsize=14)\n",
        "axs.yaxis.set_major_formatter(ticker.EngFormatter()) # Make y-ticks more human readable\n",
        "\n",
        "# Improve font\n",
        "plt.rc('font', size=12)\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "cWXUyGEhwI7B"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}